from xmlrpc.client import boolean
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Union, Iterable, Sized, Tuple
import itertools

import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean: float = 0., std: float = 1.):
    """Fill ``tensor`` in place with truncated-normal samples.

    Each element is a standard-normal draw restricted to the open interval
    (-2, 2), then scaled by ``std`` and shifted by ``mean``. The result
    therefore lies within ``mean +/- 2 * std``.

    Four candidates are drawn per element and the first one inside (-2, 2)
    is kept. Any element whose four candidates all fall outside the
    interval is redrawn until it is valid.

    Args:
        tensor: tensor to overwrite (e.g. a ``torch.nn.Parameter``). It is
            written through ``.data``, so no autograd history is recorded.
        mean: mean of the resulting distribution.
        std: standard deviation of the (untruncated) distribution.
    """
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    # Index of the first valid candidate per element (0 when none is valid).
    ind = valid.max(-1, keepdim=True)[1]
    sample = tmp.gather(-1, ind).squeeze(-1)

    # Without this loop, elements with no valid candidate silently kept an
    # out-of-range value (about 4e-6 per element, i.e. near-certain on large
    # tensors). Rejection-sample just those elements until they are in range.
    pending = ~valid.any(-1)
    while pending.any():
        redraw = tmp.new_empty(int(pending.sum())).normal_()
        sample[pending] = redraw
        still_pending = torch.zeros_like(pending)
        still_pending[pending] = (redraw >= 2) | (redraw <= -2)
        pending = still_pending

    tensor.data.copy_(sample)
    tensor.data.mul_(std).add_(mean)


class ActivationLayer(torch.nn.Module):
    """Base class that owns the ``weight`` and ``bias`` parameters of a layer.

    ``weight`` has shape (in_features, out_features) and ``bias`` has shape
    (out_features,). Both are allocated uninitialised with ``torch.empty``;
    subclasses must initialise them and implement :meth:`forward`.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int):
        super().__init__()
        weight_shape = (in_features, out_features)
        self.weight = torch.nn.Parameter(torch.empty(weight_shape))
        self.bias = torch.nn.Parameter(torch.empty(out_features))

    def forward(self, x):
        # Abstract: concrete layers override this.
        raise NotImplementedError("abstract methodd called")

class ExpUnit(ActivationLayer):
    """Dense layer with positive weights ``exp(W)`` and a squashing activation.

    ``forward`` computes ``z = x @ exp(weight) + bias`` and then applies:

    * ``'relu_n'`` (default): ``0.99 * clip(z, 0, 1) + 0.01 * z``, a leaky
      ReLU that is also capped at 1;
    * ``'sigmoid'``: ``sigmoid(z)``;
    * ``'tanh'``: ``tanh(z)``;
    * any other value: ``z`` unchanged.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu_n'):
        super().__init__(in_features, out_features)
        # Log-space weights in [-20, 2]; effective weights exp(W) lie in
        # roughly [2e-9, 7.4].
        torch.nn.init.uniform_(self.weight, a=-20.0, b=2.0)
        truncated_normal_(self.bias, std=0.5)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        z = x @ torch.exp(self.weight) + self.bias
        mode = self.activation
        if mode == 'relu_n':
            return (1 - 0.01) * torch.clip(z, 0, 1) + 0.01 * z
        if mode == 'sigmoid':
            return torch.sigmoid(z)
        if mode == 'tanh':
            return torch.tanh(z)
        return z
    
    
    
    
class ConvUnit(ActivationLayer):
    """Dense layer with positive weights ``exp(W)`` and a convex activation.

    Computes ``f(x @ exp(weight) + bias)``, where ``f`` is selected by
    ``activation``:

    * 'relu' (default);
    * 'elu' with alpha=0.01;
    * 'leaky_Relu' with negative_slope=0.01;
    * 'softplus'.

    Any other value returns the affine output unchanged.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu'):
        super().__init__(in_features, out_features)
        # Weights are stored in log space; forward exponentiates them so the
        # effective weights are strictly positive.
        torch.nn.init.uniform_(self.weight, a=-20.0, b=2.0)
        truncated_normal_(self.bias, std=0.5)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        affine = x @ torch.exp(self.weight) + self.bias
        kind = self.activation
        if kind == 'relu':
            return F.relu(affine)
        if kind == 'elu':
            return F.elu(affine, alpha=0.01)
        if kind == 'leaky_Relu':
            return F.leaky_relu(affine, negative_slope=0.01)
        if kind == 'softplus':
            return F.softplus(affine)
        return affine
    

class ConcUnit(ActivationLayer):
    """Dense layer with positive weights ``exp(W)`` and a reflected activation.

    ``forward`` computes ``z = x @ exp(weight) + bias`` and returns
    ``-f(-z)``, where ``f`` is selected by ``activation``:

    * 'relu' (default);
    * 'elu' with alpha=0.01;
    * 'leaky_Relu' with negative_slope=0.01;
    * 'softplus'.

    Each of these ``f`` is convex and non-decreasing, so ``-f(-z)`` is
    concave and non-decreasing in ``z``. Any other ``activation`` value is
    treated as the identity, and ``z`` itself is returned.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu'):
        super().__init__(in_features, out_features)
        # Log-space weights; exp() in forward makes the effective weights positive.
        torch.nn.init.uniform_(self.weight, a=-20.0, b=2.0)
        truncated_normal_(self.bias, std=0.5)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        z = x @ torch.exp(self.weight) + self.bias
        neg = z * (-1)

        if self.activation == 'relu':
            reflected = F.relu(neg)
        elif self.activation == 'elu':
            reflected = F.elu(neg, alpha=0.01)
        elif self.activation == 'leaky_Relu':
            reflected = F.leaky_relu(neg, negative_slope=0.01)
        elif self.activation == 'softplus':
            reflected = F.softplus(neg)
        else:
            # Identity activation: -(-z) == z. Previously -z was returned,
            # which flipped the unit's sign and monotonicity.
            reflected = neg
        return reflected * (-1)
    
class ReLUUnit(ActivationLayer):
    """Unconstrained dense layer ``f(x @ weight + bias)``.

    ``weight`` is Xavier-uniform initialised and may take any sign. ``f`` is
    selected by ``activation``:

    * 'relu' (default);
    * 'elu' with alpha=0.01;
    * 'leaky_Relu' with negative_slope=0.01;
    * 'softplus'.

    Any other value returns the affine output unchanged.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu'):
        super().__init__(in_features, out_features)
        torch.nn.init.xavier_uniform_(self.weight)
        truncated_normal_(self.bias, std=0.5)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        pre = torch.matmul(x, self.weight) + self.bias
        name = self.activation
        if name == 'relu':
            return F.relu(pre)
        if name == 'elu':
            return F.elu(pre, alpha=0.01)
        if name == 'leaky_Relu':
            return F.leaky_relu(pre, negative_slope=0.01)
        if name == 'softplus':
            return F.softplus(pre)
        return pre


class RefReLUUnit(ActivationLayer):
    """Unconstrained dense layer with a reflected activation ``-f(-(x @ W + b))``.

    ``weight`` is Xavier-uniform initialised and may take any sign. ``f`` is
    selected by ``activation``:

    * 'relu' (default);
    * 'elu' with alpha=0.01;
    * 'leaky_Relu' with negative_slope=0.01;
    * 'softplus'.

    Each of these is convex, so ``-f(-z)`` is concave in ``z``. Any other
    ``activation`` value is treated as the identity, and the affine output
    ``z`` itself is returned.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation: str = 'relu'):
        super().__init__(in_features, out_features)
        torch.nn.init.xavier_uniform_(self.weight)
        truncated_normal_(self.bias, std=0.5)
        self.activation = activation

    def forward(self, x):
        z = x @ self.weight + self.bias
        neg = z * (-1)

        if self.activation == 'relu':
            reflected = F.relu(neg)
        elif self.activation == 'elu':
            reflected = F.elu(neg, alpha=0.01)
        elif self.activation == 'leaky_Relu':
            reflected = F.leaky_relu(neg, negative_slope=0.01)
        elif self.activation == 'softplus':
            reflected = F.softplus(neg)
        else:
            # Identity activation: -(-z) == z. Previously -z was returned,
            # which flipped the unit's sign.
            reflected = neg
        return reflected * (-1)


class FCLayer(ActivationLayer):
    """Affine layer with strictly positive weights: ``x @ exp(W) + b``.

    ``W`` is stored in log space and initialised by ``truncated_normal_`` with
    mean -10 and std 3 (truncated at two standard deviations). The effective
    weights therefore start very small, roughly exp(-16) to exp(-4).
    """

    def __init__(self,
                 in_features: int,
                 out_features: int):
        super().__init__(in_features, out_features)
        truncated_normal_(self.weight, mean=-10.0, std=3)
        truncated_normal_(self.bias, std=0.5)

    def forward(self, x):
        positive_weight = self.weight.exp()
        return x @ positive_weight + self.bias

class FCLayer_notexp(ActivationLayer):
    """Plain affine layer ``x @ W + b`` with unconstrained, Xavier-initialised ``W``."""

    def __init__(self,
                 in_features: int,
                 out_features: int):
        super().__init__(in_features, out_features)
        torch.nn.init.xavier_uniform_(self.weight)
        truncated_normal_(self.bias, std=0.5)

    def forward(self, x):
        projected = torch.matmul(x, self.weight)
        return projected + self.bias

class ExpBatchNorm1d(nn.Module):
    """Batch normalisation whose per-feature scale is parameterised as ``exp(gamma)``.

    Because the scale ``exp(gamma)`` is strictly positive, in eval mode (fixed
    running statistics) each feature is mapped by an increasing affine
    function.

    Args:
        num_features: number of features (size of dim 1 of the input).
        eps: added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(ExpBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable parameters. gamma is in log space, so exp(0) = 1 is the
        # initial scale.
        self.gamma = nn.Parameter(torch.zeros(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalise ``x`` feature-wise; statistics are reduced over dim 0.

        In training mode, batch statistics are used (biased variance) and the
        running statistics are updated. In eval mode, the running statistics
        are used.
        """
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # Update the buffers in place and outside autograd. Assigning the
            # graph-attached expression (as before) made the buffers require
            # grad and kept a growing autograd history alive across batches.
            # It also let eval-mode gradients flow into past inputs.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean = self.running_mean
            var = self.running_var

        gamma = torch.exp(self.gamma)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return gamma * x_normalized + self.beta




class MultiLayerPerceptron(torch.nn.Module):
    """Feed-forward network: a stack of hidden layers, then a single-output layer.

    Args:
        input_size: number of input features.
        hidden_size: widths of the hidden layers, in order. May be empty, in
            which case the input feeds the output layer directly.
        hidden_layer: layer class for every hidden layer, called as
            ``hidden_layer(in_features, out_features)``.
        fully_connected_layer: layer class for the final layer, called as
            ``fully_connected_layer(in_features, 1)``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Tuple = (),
                 hidden_layer: ActivationLayer = ReLUUnit,
                 fully_connected_layer: ActivationLayer = FCLayer_notexp):

        super(MultiLayerPerceptron, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Consecutive (in, out) width pairs; layers are built in the same order
        # as before, so parameter initialisation consumes the RNG identically.
        widths = [input_size, *hidden_size]
        self.hidden_layers = torch.nn.ModuleList([
            hidden_layer(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ])

        # widths[-1] is the last hidden width, or input_size when there are no
        # hidden layers (the default hidden_size=() used to raise IndexError).
        self.fclayer = fully_connected_layer(widths[-1], 1)

    def forward(self, x):
        output = x
        for layer in self.hidden_layers:
            output = layer(output)
        return self.fclayer(output)




class InteractionLayer(torch.nn.Module):
    """Learned bilinear coupling between a convex-side and a concave-side input.

    For each row ``b`` of the batch it returns
    ``sum_{i,j} conv[b, i] * alpha[i, j] * conc[b, j]``, where ``alpha`` is a
    learnable (conv_dim, conc_dim) matrix with Xavier-uniform initialisation.
    """

    def __init__(self, conv_dim, conc_dim):
        """
        Args:
            conv_dim: width of the convex-side input (x_conv + x_monoconv).
            conc_dim: width of the concave-side input (x_conc + x_monoconc).
        """
        super().__init__()
        init_alpha = torch.empty(conv_dim, conc_dim)
        nn.init.xavier_uniform_(init_alpha)
        self.alpha = nn.Parameter(init_alpha)

    def forward(self, conv, conc):
        """
        Args:
            conv: tensor of shape (batch_size, conv_dim).
            conc: tensor of shape (batch_size, conc_dim).

        Returns:
            Tensor of shape (batch_size, 1) holding the bilinear interaction.
        """
        projected = torch.matmul(conv, self.alpha)  # (batch_size, conc_dim)
        return torch.sum(projected * conc, dim=1, keepdim=True)




class COMONet(torch.nn.Module):
    """Network with convex, concave, monotone and unconstrained feature branches.

    Input columns are split by index into six groups: ``conv_features``,
    ``monoconv_features``, ``conc_features``, ``monoconc_features``,
    ``mono_features`` and every remaining column ("unconst"). Each group has
    its own stack of layers.

    Between layers, outputs are concatenated into the next layer's input:

    * the conv branch receives the mono and unconst outputs, plus the
      monoconv output after layer 0;
    * the conc branch likewise receives them, plus the monoconc output after
      layer 0;
    * the mono branch receives the unconst output.

    The final branch outputs are concatenated and passed through
    ``fully_connected_layer`` (one output unit). When both the convex side
    (conv + monoconv) and the concave side (conc + monoconc) are non-empty,
    a bilinear ``interaction_layer`` term computed on the raw inputs of those
    columns is added.
    """
    def __init__(self,
                input_size: int,
                conv_features: list,
                monoconv_features: list,
                conc_features: list,
                monoconc_features: list,
                mono_features: list,
                conv_layer_size: Tuple = (),
                monoconv_layer_size: Tuple = (),
                conc_layer_size: Tuple = (),
                monoconc_layer_size: Tuple = (),
                mono_layer_size: Tuple = (),
                unconst_layer_size: Tuple = (),
                batch_norm : bool = True,
                exp_unit: ActivationLayer = ExpUnit, # Relu-n(X @ exp(W) + b), positive weights
                conv_unit: ActivationLayer = ConvUnit, # Relu(X @ exp(W) + b), positive weights
                conc_unit: ActivationLayer = ConcUnit, # -Relu(-(X @ exp(W) + b)), positive weights
                relu_unit: ActivationLayer = ReLUUnit, # Relu(X @ W + b), unconstrained W
                ref_relu_unit: ActivationLayer = RefReLUUnit, # -Relu(-(X @ W + b)), unconstrained W
                fully_connected_layer: ActivationLayer = FCLayer,
                interaction_layer: nn.Module  = InteractionLayer,
                exp_batchnorm1d: nn.Module = ExpBatchNorm1d,
                activation: list =['relu','relu_n']):
        """Build the per-group layer stacks.

        Args:
            input_size: total number of input columns.
            conv_features, monoconv_features, conc_features,
            monoconc_features, mono_features: column indices of the input
                belonging to each group. All other columns form the
                unconstrained group.
            conv_layer_size, conc_layer_size, mono_layer_size,
            unconst_layer_size: layer widths of each branch. NOTE(review):
                forward() iterates len(conv_layer_size) layers for all four
                branches, so these presumably must have equal length -- confirm.
            monoconv_layer_size, monoconc_layer_size: only element [0] is
                used (these branches have a single layer).
            batch_norm: if True, ExpBatchNorm1d modules are built per branch.
                NOTE(review): their use in forward() is commented out, so the
                modules are registered but never applied.
            exp_unit, conv_unit, conc_unit, relu_unit, ref_relu_unit,
            fully_connected_layer, interaction_layer, exp_batchnorm1d: layer
                classes used to build the branches.
            activation: activation[0] is passed to the conv/monoconv/conc/
                monoconc/unconst units; activation[1] is passed to the mono
                (exp_unit) layers. The default list is only read, never
                mutated.
        """
        super(COMONet,self).__init__()

        self.input_size = input_size
        self.conv_size = len(conv_features)
        self.conc_size = len(conc_features)
        self.monoconv_size = len(monoconv_features)
        self.monoconc_size = len(monoconc_features)
        self.mono_size = len(mono_features)

        self.unconst_size = input_size - self.conv_size - self.monoconv_size - self.conc_size - self.monoconc_size - self.mono_size

        self.conv_features = conv_features
        self.monoconv_features = monoconv_features
        self.conc_features = conc_features
        self.monoconc_features = monoconc_features
        self.mono_features = mono_features
        # Every column index not claimed by another group.
        # NOTE(review): ordering comes from set iteration; sorted(...) would make it explicit.
        self.unconst_features = list(set(list(range(input_size))).difference(conv_features,monoconv_features,conc_features,monoconc_features,mono_features))

        self.conv_layer_size = conv_layer_size
        self.monoconv_layer_size = monoconv_layer_size
        self.conc_layer_size = conc_layer_size
        self.monoconc_layer_size = monoconc_layer_size
        self.mono_layer_size = mono_layer_size
        self.unconst_layer_size = unconst_layer_size
        self.batch_norm = batch_norm
        self.activation = activation

        # Conv branch. Layer 0 is relu_unit on the raw conv features.
        # Layer 1 is conv_unit whose input also includes the monoconv, mono
        # and unconst outputs of layer 0. Layers >= 2 are conv_unit whose input
        # includes the previous mono and unconst outputs.
        self.conv_layers = torch.nn.ModuleList([
            relu_unit(self.conv_size, conv_layer_size[i],activation = self.activation[0]) if i == 0 else conv_unit(
                conv_layer_size[i - 1] + monoconv_layer_size[i - 1] + mono_layer_size[i - 1] + unconst_layer_size[i - 1], conv_layer_size[i],activation = self.activation[0]) if i == 1 else conv_unit(
                    conv_layer_size[i - 1] + mono_layer_size[i - 1] + unconst_layer_size[i - 1], conv_layer_size[i],activation = self.activation[0])
            for i in range(len(conv_layer_size))
        ])

        # Monoconv branch: a single conv_unit layer; its output is merged into the conv branch.
        self.monoconv_layers = torch.nn.ModuleList([
            conv_unit(self.monoconv_size, monoconv_layer_size[0],activation = self.activation[0])
        ])

        # Conc branch: mirrors the conv branch using the reflected units
        # (ref_relu_unit at layer 0, conc_unit afterwards).
        self.conc_layers = torch.nn.ModuleList([
            ref_relu_unit(self.conc_size, conc_layer_size[i],activation = self.activation[0]) if i == 0 else conc_unit(
                conc_layer_size[i - 1] + monoconc_layer_size[i - 1] + mono_layer_size[i - 1] + unconst_layer_size[i - 1], conc_layer_size[i],activation = self.activation[0]) if i == 1 else conc_unit(
                    conc_layer_size[i - 1] + mono_layer_size[i - 1] + unconst_layer_size[i - 1], conc_layer_size[i],activation = self.activation[0])
            for i in range(len(conc_layer_size))
        ])

        # Monoconc branch: a single conc_unit layer; its output is merged into the conc branch.
        self.monoconc_layers = torch.nn.ModuleList([
            conc_unit(self.monoconc_size, monoconc_layer_size[0],activation = self.activation[0])
        ])

        # Mono branch: exp_unit layers; from layer 1 on, the input also includes the unconst output.
        self.mono_layers = torch.nn.ModuleList([
            exp_unit(self.mono_size if i == 0 else mono_layer_size[i-1] + unconst_layer_size[i - 1], mono_layer_size[i],activation = self.activation[1])
            for i in range(len(mono_layer_size))
        ])

        # Unconstrained branch: a plain relu_unit stack.
        self.unconst_layers = torch.nn.ModuleList([
            relu_unit(self.unconst_size if i == 0 else unconst_layer_size[i-1], unconst_layer_size[i],activation = self.activation[0])
            for i in range(len(unconst_layer_size))
        ])

        # NOTE(review): these normalisation modules are built (and appear in
        # parameters()/state_dict) but forward() never calls them.
        if self.batch_norm == True:
            self.conv_norms = torch.nn.ModuleList([
                exp_batchnorm1d(conv_layer_size[i]) for i in range(len(self.conv_layer_size))
            ])
            self.monoconv_norms = torch.nn.ModuleList([
                exp_batchnorm1d(monoconv_layer_size[0])
            ])

            self.conc_norms = torch.nn.ModuleList([
                exp_batchnorm1d(conc_layer_size[i]) for i in range(len(self.conc_layer_size))
            ])
            self.monoconc_norms = torch.nn.ModuleList([
                exp_batchnorm1d(monoconc_layer_size[0])
            ])

            self.mono_norms = torch.nn.ModuleList([
                exp_batchnorm1d(mono_layer_size[i]) for i in range(len(self.mono_layers))
            ])
            self.unconst_norms = torch.nn.ModuleList([
                exp_batchnorm1d(unconst_layer_size[i]) for i in range(len(self.unconst_layers))
            ])

        # Output layer over the concatenated last-layer outputs of the four
        # main branches (the monoconv/monoconc outputs reach it only through
        # the conv/conc branches).
        self.fclayer = fully_connected_layer(conv_layer_size[len(conv_layer_size)-1]
                                            + conc_layer_size[len(conc_layer_size)-1]
                                            + mono_layer_size[len(mono_layer_size)-1] #+ conf_layer_size[len(conf_layer_size)-1]
                                            + unconst_layer_size[len(unconst_layer_size)-1],1)

        # The bilinear convex x concave interaction only exists when both sides have columns.
        if (self.conv_size + self.monoconv_size) != 0 and (self.conc_size + self.monoconc_size) != 0:
            self.interaction_layer = interaction_layer(self.conv_size + self.monoconv_size,
                                                        self.conc_size + self.monoconc_size)
        else:
            self.interaction_layer = None


    def forward(self,x):
        """Run all branches and combine them.

        Args:
            x: 2-D tensor whose columns are selected by the feature-index lists.

        Returns:
            ``fclayer`` output plus the interaction term (shape (batch, 1) with
            the default layer classes).
        """

        x_conv  = x[:, self.conv_features]
        x_monoconv = x[:, self.monoconv_features]
        x_conc  = x[:, self.conc_features]
        x_monoconc = x[:, self.monoconc_features]
        x_mono  = x[:, self.mono_features]
        x_unconst  = x[:, self.unconst_features]

        # Depth is taken from conv_layer_size and applied to every branch.
        for i in range(len(self.conv_layer_size)):
            if i == 0 :
                conv_output = self.conv_layers[i](x_conv)
                monoconv_output = self.monoconv_layers[i](x_monoconv)
                conc_output = self.conc_layers[i](x_conc)
                monoconc_output = self.monoconc_layers[i](x_monoconc)
                mono_output = self.mono_layers[i](x_mono)
                unconst_output = self.unconst_layers[i](x_unconst)

                # NOTE(review): batch normalisation is disabled here (commented out).
                # # batch normalization
                # if self.batch_norm == True:
                #     conv_output = self.conv_norms[i](conv_output)
                #     monoconv_output = self.monoconv_norms[i](monoconv_output)
                #     conc_output = self.conc_norms[i](conc_output)
                #     monoconc_output = self.monoconc_norms[i](monoconc_output)
                #     mono_output = self.mono_norms[i](mono_output)
                #     unconst_output = self.unconst_norms[i](unconst_output)

                # Build the inputs of layer 1.
                # NOTE(review): this also runs when there is only one layer.
                # The fclayer input width then no longer matches, so depth 1
                # appears unsupported -- confirm.
                conv_output = torch.cat([conv_output, monoconv_output, mono_output, unconst_output], dim=1)
                conc_output = torch.cat([conc_output, monoconc_output, mono_output, unconst_output], dim=1)
                mono_output = torch.cat([mono_output, unconst_output], dim=1)

            else :
                conv_output = self.conv_layers[i](conv_output)
                conc_output = self.conc_layers[i](conc_output)
                mono_output = self.mono_layers[i](mono_output)
                unconst_output = self.unconst_layers[i](unconst_output)

                # # batch normalization
                # if self.batch_norm == True:
                #     conv_output = self.conv_norms[i](conv_output)
                #     conc_output = self.conc_norms[i](conc_output)
                #     mono_output = self.mono_norms[i](mono_output)
                #     unconst_output = self.unconst_norms[i](unconst_output)


                # Concatenate for the next layer, except after the last layer,
                # where the raw branch outputs go to fclayer.
                if i != (len(self.conv_layer_size)-1):
                    conv_output = torch.cat([conv_output, mono_output, unconst_output], dim=1)
                    conc_output = torch.cat([conc_output, mono_output, unconst_output], dim=1)
                    mono_output = torch.cat([mono_output, unconst_output], dim=1)
                else:
                    continue

        out = self.fclayer(torch.cat([conv_output,conc_output,mono_output,unconst_output],dim = 1))

        # The interaction term uses the raw input columns, not the branch outputs.
        conv = torch.cat([x_conv, x_monoconv], dim=1)  # shape: (batch_size, conv_dim)
        conc = torch.cat([x_conc, x_monoconc], dim=1)    # shape: (batch_size, conc_dim)

        interaction = 0
        if self.interaction_layer is not None:
            interaction = self.interaction_layer(conv, conc)  # shape: (batch_size, 1)

        out = out + interaction

        return out


 